from distilabel.distiset import Distiset
import openai 
from openai import OpenAI
from datasets import Dataset
import random
import os
import time

def getOutput(prompt: str, max_retries=30, *, max_backoff: float = 60.0) -> str:
    """Send ``prompt`` to the chat model and return the reply text.

    The request is retried on any exception, sleeping between attempts with
    a capped exponential delay plus up to one second of random jitter.

    Args:
        prompt: Text sent as the user message.
        max_retries: Total number of attempts; must be at least 1.
        max_backoff: Upper bound, in seconds, on the exponential part of the
            delay between attempts. Without a cap, ``2 ** attempt`` reaches
            years of sleep well before 30 attempts are used up.

    Returns:
        The ``content`` of the first choice's message, which is also printed.
        NOTE(review): the SDK appears to type this field as optional, so the
        value may be ``None`` -- confirm callers tolerate that.

    Raises:
        ValueError: If ``max_retries`` is smaller than 1.
        Exception: Whatever the final attempt raised, once retries run out.
    """
    if max_retries < 1:
        # An empty retry loop would otherwise fall through and return None.
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")

    # The constructor arguments never change between attempts, so build the
    # client once. A failure here is raised immediately rather than retried.
    client = OpenAI(api_key="", base_url="")

    for attempt in range(max_retries):
        try:
            completion = client.chat.completions.create(
                model='deepseek-v3',
                messages=[
                    {'role': 'system', 'content': 'You are a helpful assistant.'},
                    {'role': 'user', 'content': prompt},
                ],
                temperature=0,
            )
            output = completion.choices[0].message.content
            print(output)
            return output
        except Exception:
            if attempt == max_retries - 1:
                # Out of attempts: surface the last error to the caller.
                raise
            # Capped exponential backoff plus up to 1 s of jitter.
            sleep_time = min(max_backoff, 2 ** attempt) + random.random()
            time.sleep(sleep_time)



def _load_and_show(path):
    """Load the distiset stored at ``path``, print it, and return it."""
    loaded = Distiset.load_from_disk(path)
    print(loaded)
    return loaded


# Each of the first four distisets is printed right after it is loaded.
distiset1 = _load_and_show(".../SFTData/ToMi_May_test")

distiset2 = _load_and_show(".../SFTData/HiToM_May_test")

distiset3 = _load_and_show(".../SFTData/ExploreToM_May_test")

distiset4 = _load_and_show(".../SFTData/ToMBench_May_test")

# Both SocialIqa parts are loaded before either one is printed.
distiset5 = Distiset.load_from_disk(".../SFTData/SocialIqa_May_1_test")
distiset6 = Distiset.load_from_disk(".../SFTData/SocialIqa_May_2_test")
print(distiset5, distiset6, sep="\n")


# Take the 'train' split of the 'default' configuration from every distiset.
dataset1, dataset2, dataset3, dataset4, dataset5, dataset6 = [
    bundle['default']['train']
    for bundle in (
        distiset1,
        distiset2,
        distiset3,
        distiset4,
        distiset5,
        distiset6,
    )
]

# One dataset summary per line, in the same order as above.
print(dataset1, dataset2, dataset3, dataset4, dataset5, dataset6, sep="\n")

# Module-level accumulator shared by every filter_dataset call: each call
# appends to it, so it ends up holding the kept examples from all datasets.
filtered_dataset = []


def filter_dataset(dataset, expected_answer_key='answer'):
    """Keep the examples whose response the LLM judge marks as correct.

    For every example, the conversation under ``'messages'`` and the
    reference answer under ``expected_answer_key`` are put into a prompt and
    sent to ``getOutput``. If the reply contains ``'True'``, a
    ``{"messages": ...}`` record is appended to the module-level
    ``filtered_dataset`` list.

    Args:
        dataset: Iterable of mapping-like examples; each must provide
            ``'messages'`` and ``expected_answer_key``.
        expected_answer_key: Name of the field holding the reference answer.

    Returns:
        The shared ``filtered_dataset`` list itself (not a per-call copy), so
        it also contains records kept by earlier calls.
    """
    for example in dataset:
        message = example['messages']
        answer = example[expected_answer_key]
        prompt = f"""
        There are messages that contain the user's question and the corresponding response:
        {message}
        This is the correct answer:
        {answer}
        Is the final answer correct? Output 'True' or 'False' only.
        """
        result = getOutput(prompt)
        # Guard against a non-string reply (e.g. None): `'True' in None`
        # raises TypeError and would abort the run before anything is saved.
        # An unusable reply is treated as "not verified" and the example is
        # dropped.
        if isinstance(result, str) and 'True' in result:
            filtered_dataset.append({"messages": message})
    return filtered_dataset


# Run the LLM judge over every benchmark; the name of the reference-answer
# column differs between datasets. filter_dataset appends the kept examples
# to the module-level `filtered_dataset` list and returns that same list, so
# filtered1..filtered6 all refer to one shared, growing list.
filtered1 = filter_dataset(dataset1, expected_answer_key='answer')
filtered2 = filter_dataset(dataset2, expected_answer_key='answer')
filtered3 = filter_dataset(dataset3, expected_answer_key='expected_answer')
filtered4 = filter_dataset(dataset4, expected_answer_key='答案\nANSWER')
filtered5 = filter_dataset(dataset5, expected_answer_key='label_letter')
filtered6 = filter_dataset(dataset6, expected_answer_key='label_letter')


# Build one dataset from everything that passed the filter.
final_dataset = Dataset.from_list(filtered_dataset)
print(final_dataset)

final_dataset.save_to_disk(".../SFTData/all_May_entire_test")
# load_from_disk returns a new Dataset instead of refreshing the instance it
# is called on, so rebind the name; the prints below then show what was
# actually written to disk.
final_dataset = Dataset.load_from_disk(".../SFTData/all_May_entire_test")
print(final_dataset)
# Indexing an empty dataset fails, so only preview a row when one exists.
if len(final_dataset) > 0:
    print(final_dataset[0])
else:
    print("No examples passed the correctness filter.")



